{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import torch\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 384, 384])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = Image.open('../examples/simple/img.jpg')\n",
    "img = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])(img).unsqueeze(0)\n",
    "img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'pytorch_pretrained_vit' from '/home/luke/projects/experiments/ViT-PyTorch/pytorch_pretrained_vit/__init__.py'>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from importlib import reload\n",
    "import pytorch_pretrained_vit\n",
    "reload(pytorch_pretrained_vit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = pytorch_pretrained_vit.ViT(name='B_16', pretrained=False, num_classes=21843)\n",
    "model = pytorch_pretrained_vit.ViT(name='B_16_imagenet1k', pretrained=False, num_classes=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list(model.state_dict().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# npz = np.load('imagenet21k_ViT-B_16.npz')\n",
    "npz = np.load('ViT-B_16.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# npz.files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert(npz, state_dict):\n",
    "    new_state_dict = {}\n",
    "    pytorch_k2v = {jax_to_pytorch(k): v for k, v in npz.items()}\n",
    "    for pytorch_k, pytorch_v in state_dict.items():\n",
    "        \n",
    "        # Naming\n",
    "        if 'self_attn.out_proj.weight' in pytorch_k:\n",
    "            v = pytorch_k2v[pytorch_k]\n",
    "            v = v.reshape(v.shape[0] * v.shape[1], v.shape[2])\n",
    "        elif 'self_attn.in_proj_' in pytorch_k:\n",
    "            v = np.stack((pytorch_k2v[pytorch_k + '*q'], \n",
    "                          pytorch_k2v[pytorch_k + '*k'], \n",
    "                          pytorch_k2v[pytorch_k + '*v']), axis=0)\n",
    "        else:\n",
    "            if pytorch_k not in pytorch_k2v:\n",
    "                print(pytorch_k, list(pytorch_k2v.keys()))\n",
    "                assert False\n",
    "            v = pytorch_k2v[pytorch_k]\n",
    "        v = torch.from_numpy(v)\n",
    "        \n",
    "        # Sizing\n",
    "        if '.weight' in pytorch_k:\n",
    "            if len(pytorch_v.shape) == 2:\n",
    "                v = v.transpose(0, 1)\n",
    "            if len(pytorch_v.shape) == 4:\n",
    "                v = v.permute(3, 2, 0, 1)\n",
    "        if ('proj.weight' in pytorch_k):\n",
    "            v = v.transpose(0, 1)\n",
    "            v = v.reshape(-1, v.shape[-1]).T\n",
    "        if ('proj.bias' in pytorch_k):\n",
    "            print(pytorch_k, v.shape)\n",
    "        if ('attn.proj_' in pytorch_k and 'weight' in pytorch_k):\n",
    "            v = v.permute(0, 2, 1)\n",
    "            v = v.reshape(-1, v.shape[-1])\n",
    "        if 'attn.proj_' in pytorch_k and 'bias' in pytorch_k:\n",
    "            v = v.reshape(-1)\n",
    "        new_state_dict[pytorch_k] = v\n",
    "    return new_state_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jax_to_pytorch(k):\n",
    "    k = k.replace('Transformer/encoder_norm', 'norm')\n",
    "    k = k.replace('LayerNorm_0', 'norm1')\n",
    "    k = k.replace('LayerNorm_2', 'norm2')\n",
    "    k = k.replace('MlpBlock_3/Dense_0', 'pwff.fc1')\n",
    "    k = k.replace('MlpBlock_3/Dense_1', 'pwff.fc2')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/out', 'proj')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/query', 'attn.proj_q')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/key', 'attn.proj_k')\n",
    "    k = k.replace('MultiHeadDotProductAttention_1/value', 'attn.proj_v')\n",
    "    k = k.replace('Transformer/posembed_input', 'positional_embedding')\n",
    "    k = k.replace('encoderblock_', 'blocks.')\n",
    "    k = 'patch_embedding.bias' if k == 'embedding/bias' else k\n",
    "    k = 'patch_embedding.weight' if k == 'embedding/kernel' else k\n",
    "    k = 'class_token' if k == 'cls' else k\n",
    "    k = k.replace('head', 'fc')\n",
    "    k = k.replace('kernel', 'weight')\n",
    "    k = k.replace('scale', 'weight')\n",
    "    k = k.replace('/', '.')\n",
    "    k = k.lower()\n",
    "    return k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transformer.blocks.0.proj.bias torch.Size([768])\n",
      "transformer.blocks.1.proj.bias torch.Size([768])\n",
      "transformer.blocks.2.proj.bias torch.Size([768])\n",
      "transformer.blocks.3.proj.bias torch.Size([768])\n",
      "transformer.blocks.4.proj.bias torch.Size([768])\n",
      "transformer.blocks.5.proj.bias torch.Size([768])\n",
      "transformer.blocks.6.proj.bias torch.Size([768])\n",
      "transformer.blocks.7.proj.bias torch.Size([768])\n",
      "transformer.blocks.8.proj.bias torch.Size([768])\n",
      "transformer.blocks.9.proj.bias torch.Size([768])\n",
      "transformer.blocks.10.proj.bias torch.Size([768])\n",
      "transformer.blocks.11.proj.bias torch.Size([768])\n"
     ]
    }
   ],
   "source": [
    "new_state_dict = convert(npz, model.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(new_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "\n",
    "def check(M):\n",
    "    labels_map = json.load(open('../examples/simple/labels_map.txt'))\n",
    "    labels_map = [labels_map[str(i)] for i in range(1000)]\n",
    "    with torch.no_grad():\n",
    "        outputs = M(img)\n",
    "    print('-----')\n",
    "    for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():\n",
    "        prob = torch.softmax(outputs, dim=1)[0, idx].item()\n",
    "        print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----\n",
      "[388] giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca           (99.51%)\n",
      "[387] lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens         (0.16%)\n",
      "[297] sloth bear, Melursus ursinus, Ursus ursinus                                 (0.05%)\n",
      "[295] American black bear, black bear, Ursus americanus, Euarctos americanus      (0.03%)\n",
      "[296] ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus                 (0.03%)\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "check(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def printhook(self, input, output):\n",
    "    print('Inside ' + self.__class__.__name__ + ' forward')\n",
    "    print('input: ', type(input))\n",
    "    print('input[0]: ', type(input[0]))\n",
    "    print('input size:', input[0].size())\n",
    "    print('input norm:', input[0].norm())\n",
    "    if isinstance(output, tuple):\n",
    "        output = output[0]\n",
    "    print('output size:', output.data.size())\n",
    "    print('output norm:', output.data.norm())\n",
    "    print('-----------\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# h1 = m.blocks[0].attn.proj.register_forward_hook(printhook)  # m.blocks[0].register_forward_hook(printhook) \n",
    "# h2 = model.transformer.blocks[0].proj.register_forward_hook(printhook)  # model.transformer.layers[0].register_forward_hook(printhook) \n",
    "# m(img)\n",
    "# model(img)\n",
    "# h1.remove()\n",
    "# h2.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
